{
 "cells": [
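  {
   "cell_type": "markdown",
   "id": "3101f858",
   "metadata": {},
   "source": [
    "# Training Models with the Relax Training API\n",
    "\n",
    "Community interest in using TVM for model training keeps growing. As TVM's next-generation graph-level intermediate representation (IR), Relax needs to support training as well.\n",
    "\n",
    "A complete training workflow has been built on top of Relax. It includes:\n",
    "- **An automatic differentiation tool based on source code transformation**\n",
    "- **An optimizer abstraction** with implementations of common optimizers\n",
    "- **A loss function abstraction** with common loss functions\n",
    "- An easy-to-use **Trainer API** that ties these components together\n",
    "\n",
    "These training APIs serve several needs:\n",
    "- Training a model from scratch, using TVM's compilation optimizations to speed up training\n",
    "- Fine-tuning a model on device with TVM\n",
    "- Deploying the training process to the wide range of devices TVM supports (such as FPGAs and Raspberry Pi)\n",
    "\n",
    "This tutorial shows how to use the training APIs to:\n",
    "1. Train a model from scratch with the high-level Trainer API\n",
    "2. Train with the low-level automatic differentiation, optimizer, and loss function APIs\n",
    "3. Take a closer look at the source-level implementation of the automatic differentiation system\n",
    "\n",
    "We will train an MLP on the Fashion MNIST dataset; the same approach applies to most common models."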
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca032358",
   "metadata": {},
   "source": [
    "## Preparation\n",
    "\n",
    "First, import the required dependencies and load the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fe57ac07",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "from tvm.relax.training.loss import CrossEntropyLoss\n",
    "from tvm.relax.training.setup_trainer import SetupTrainer\n",
    "from tvm.relax.training.trainer import Trainer\n",
    "from tvm import relax\n",
    "from tvm.script import ir as I, relax as R\n",
    "from tvm.relax.transform import LegalizeOps\n",
    "from tvm.relax.training.optimizer import SGD\n",
    "\n",
    "batch_size = 64"
   ]
  },
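  {
   "cell_type": "markdown",
   "id": "60ba49b4",
   "metadata": {},
   "source": [
    "We will train the model on the [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. The code below uses torchvision, PyTorch's computer vision library, to download and preprocess the data.\n",
    "\n",
    "Note that PyTorch is used only for data loading. Batches from the PyTorch DataLoader are converted to NumPy arrays during training."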
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b7be5dbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.utils.data\n",
    "import torchvision\n",
    "import torchvision.transforms as Tr\n",
    "import torch.nn.functional as Func\n",
    "\n",
    "train_data = torchvision.datasets.FashionMNIST(\n",
    "    root=\".temp\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=Tr.Compose([Tr.ToTensor(), Tr.Lambda(torch.flatten)]),\n",
    "    target_transform=lambda x:Func.one_hot(torch.tensor(x), 10).float()\n",
    ")\n",
    "test_data = torchvision.datasets.FashionMNIST(\n",
    "    root=\".temp\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=Tr.Compose([Tr.ToTensor(), Tr.Lambda(torch.flatten)]),\n",
    "    target_transform=lambda x:Func.one_hot(torch.tensor(x), 10).float()\n",
    ")\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True)\n",
    "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
    "               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea84b7d0",
   "metadata": {},
   "source": [
    "Let's look at one sample from the data loader. Each sample in Fashion MNIST is a $28 \\times 28$ grayscale image that belongs to one of 10 clothing classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7d4e2af3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAesAAAGdCAYAAAAyiFt9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAyQUlEQVR4nO3dfXRU9b3v8c9MSCYB8mCIecLwqEItT70gMaKU1pSAZ3FL5dyF6C3I4eCqJi4ly6NiEXzoMa2eUk57qKzaIu1aotRzq7bqSq9Gg8drwGMsCzmnRkmxQWHCk3kgIZlkZt8/KKMjAfLbM8nsnXm/1tprwc7+zu+XzSbf/H77t/fXY1mWJQAA4FjeeHcAAACcH8kaAACHI1kDAOBwJGsAAByOZA0AgMORrAEAcDiSNQAADkeyBgDA4YbFuwNfFgqFdOjQIaWnp8vj8cS7OwAAQ5Zlqb29XYWFhfJ6B25M2NXVpUAgEPXnpKSkKDU1NQY9GjiOS9aHDh1SUVFRvLsBAIjSwYMHdckllwzIZ3d1dWn82JHyHwlG/Vn5+fk6cOCAoxO245J1enq6JOkaXa9hSo5zbwAApnrVo7f0Svjn+UAIBALyHwnqQP1YZaTbH723tYc0fuZfFQgEEjNZb968WY8//rj8fr+mT5+un/3sZ5o9e/YF485MfQ9TsoZ5SNYA4Dp/qzgxGLcyM9K9USVrtxiQ73DHjh2qrKzUhg0b9N5772n69OkqKyvTkSNHBqI5AECCClqhqDc3GJBkvXHjRq1evVorV67UFVdcoS1btmj48OHaunXrQDQHAEhQIVlRb24Q82nwQCCg+vp6rV27NrzP6/WqtLRUdXV1Zx3f3d2t7u7u8N/b2tpi3SUAwBAVUkjRjI2jix48MR9ZHzt2TMFgUHl5eRH78/Ly5Pf7zzq+qqpKmZmZ4Y2V4AAARIr7Xfm1a9eqtbU1vB08eDDeXQIAuETQsqLe3CDm0+A5OTlKSkpSc3NzxP7m5mbl5+efdbzP55PP54t1NwAACSDa+85uuWcd85F1SkqKZs6cqZqamvC+UCikmpoalZSUxLo5AACGvAF5zrqyslIrVqzQrFmzNHv2bG3atEkdHR1auXLlQDQHAEhQIVkKJsDIekCS9dKlS3X06FGtX79efr9fM2bMUHV19VmLzgAAiEaiTIMP2BvMKioqVFFRMVAfDwBAwnDcu8EBAOivaFd0J+xqcAAABkvob1s08W4Q9+esAQDA+TGyBgC4VjDK1eDRxA4mkjUAwLWC1uktmng3IFkDAFyLe9YAAMARGFkDAFwrJI+C8kQV7wYkawCAa4Ws01s08W7ANDgAAA7HyBoA4FrBKKfBo4kdTCRrAIBrkawB9MuHv5plHFN0yXHjGP97+cYxvSPsPZiS1GV+h8xjo6lhHeY/KJPbzNtJ7rR3Y9LXYv5NpZ7oNW/nkPk3FfzzR8YxcC+SNQDAtUKWRyEritXgUcQOJpI1AMC1EmUanNXgAAA4HCNrAIBrBeVVMIpxZzCGfRlIJGsAgGtZUd6ztrhnDQDAwOKeNQAAcARG1gAA1wpaXgWtKO5Zu+Td4CRrAIBrheRRKIpJ4pDcka2ZBgcAwOEYWQMAXCtRFpiRrAEArhX9PWumwQEAQAwwsgaidOdVrxnHXD3cvGLSB2MLjGMykrqMYyQp1dNjHDMp2byS2F96M41jLk7qMI6ZlpJqHCNJjT0njWM6LPMfq788dq1xzI8L6o1jJGnR4luMY6z/fN9WW4Ph9AKzKAp5MA0OAMDACkX5ulFWgwMAgJhgZA0AcK1EWWBGsgYAuFZI3oR4KQrJGgDgWkHLo2AUlbOiiR1M3LMGAMDhGFkDAFwrGOVq8CDT4AAADKyQ5VUoigVmIZcsMGMaHAAAh2NkDQBwLabBAQBwuJCiW9Edil1XBhTT4AAAOBwjazifx8ZvzYO4aORwIMs45v8GpxrHvPvZWOMYuz5tNy+wcawp
yzjGm25eMGRUtnlxjaOfZhnHSJJC5tfeoll/steWoY0nJtuKC6aa/9h38qgu+peiOPm7+xzJGgDgWtG/btQdydodvQQAIIExsgYAuBb1rAEAcLhEmQYnWQMAXCv656zdkazd0UsAABIYI2sAgGuFLI9C0bwUxSUlMknWAADXCkU5De6W56zd0UsAABIYI2sAgGtFXyLTHWNWkjUAwLWC8igYxbPS0cQOJnf8SgEAQAJjZA3n89j4ndIK2moqKS/XOKYg5UPzdmzU0B0z4oRxzInACOMYSbrj0teNYza9/L+MY1quMB/VHPss2zhmzKv2Crt88wdvGcfUHR9vHNMTSjKO6c00j5Gkw3PSjGNG/4etpgYF0+AAADhcUNFNZdv7tX7wueNXCgAAEljMk/WDDz4oj8cTsU2ebK/uKgAA53NmGjyazY7Nmzdr3LhxSk1NVXFxsd55553zHr9p0yZNmjRJaWlpKioq0po1a9TV1dXv9gZkGvyrX/2qXnvttc8bGcZsOwAg9uJRyGPHjh2qrKzUli1bVFxcrE2bNqmsrEwNDQ3KzT173cv27dt13333aevWrbr66qv14Ycf6pZbbpHH49HGjRv71eaAZNFhw4YpPz9/ID4aAIAwK8oSmZaN2I0bN2r16tVauXKlJGnLli16+eWXtXXrVt13331nHf/2229rzpw5uummmyRJ48aN07Jly7R79+5+tzkg96w/+ugjFRYWasKECbr55pvV1NR0zmO7u7vV1tYWsQEAMJi+nIe6u7v7PC4QCKi+vl6lpaXhfV6vV6Wlpaqrq+sz5uqrr1Z9fX14qvwvf/mLXnnlFV1//fX97l/Mk3VxcbG2bdum6upqPfHEEzpw4ICuvfZatbe393l8VVWVMjMzw1tRUVGsuwQAGKLOTINHs0lSUVFRRC6qqqrqs71jx44pGAwqLy8vYn9eXp78fn+fMTfddJMefvhhXXPNNUpOTtbEiRM1b9483X///f3+PmM+Db5w4cLwn6dNm6bi4mKNHTtWv/3tb7Vq1aqzjl+7dq0qKyvDf29rayNhAwD6JVZVtw4ePKiMjIzwfp/PF3XfzqitrdWjjz6qn//85youLtb+/ft155136pFHHtEDDzzQr88Y8JVfWVlZuvzyy7V///4+v+7z+WJ6UgAAMJWRkRGRrM8lJydHSUlJam5ujtjf3Nx8zrVaDzzwgL773e/qH//xHyVJU6dOVUdHh2699VZ9//vfl9d74UnuAX/O+uTJk2psbFRBQcFANwUASDDBv5XIjGYzkZKSopkzZ6qmpia8LxQKqaamRiUlJX3GdHZ2npWQk5JOv4HOsvr3dr2Yj6zvvvtuLVq0SGPHjtWhQ4e0YcMGJSUladmyZbFuCgCQ4GI1DW6isrJSK1as0KxZszR79mxt2rRJHR0d4dXhy5cv1+jRo8P3vRctWqSNGzfqa1/7Wnga/IEHHtCiRYvCSftCYp6sP/nkEy1btkzHjx/XxRdfrGuuuUa7du3SxRdfHOumAAAYdEuXLtXRo0e1fv16+f1+zZgxQ9XV1eFFZ01NTREj6XXr1snj8WjdunX69NNPdfHFF2vRokX653/+53636bH6OwYfJG1tbcrMzNQ8fVvDPMnx7g4cwGPjpTpWb6+ttkLXfs045nu/+j/GMfu78y580JeMTOr/247OaOoeZRwjSe29qcYx7x4xXxi6Yvwu4xg7b5z6Sc0C4xhJ+p9z6o1jPunMMo6xM7orTLP3mGv1h1cYx0y8+U9Gx/daParVi2ptbe3XfWA7zuSKire+I99I+7mi+2SP/u2a5we0r7HAq8UAAK4VtDwKRjENHk3sYKKQBwAADsfIGgDgWvFYYBYPJGsAgGtZUVTOOhPvBiRrAIBrBeVRMIpCHtHEDiZ3/EoBAEACY2QNAHCtkBXdfeeQox5ePjeSNQDAtUJR3rOOJnYwuaOXAAAkMEbWAADXCsmjUBSLxKKJHUwkawCAa/EGMwAA4AiMrOF8
/SwhF8FmIY+gz/z319HDPjOO+VPnWOOYYz0jjWMOd2Uax0jSiGHd5kHP5RiH/HjefOMY79EU45iC/7S35Hdy6WHjmDcOXmYck55qfr4vSjllHCNJvtSArTinSpQFZiRrAIBrhRTl60Zdcs/aHb9SAACQwBhZAwBcy4pyNbjlkpE1yRoA4FpU3QIAwOESZYGZO3oJAEACY2QNAHAtpsEBAHC4RHndKNPgAAA4HCNrAIBrMQ0OAIDDJUqyZhocAACHY2QNAHCtRBlZk6zheFaPvQpadgRTzSt8HQ+NMI45GfQZx1w0rNM4Zu5FHxrHSNJLR6YZx6Tc2Gwckxsyn9xLyg8Zx3zWmm8cI0mPv2deFezpq39pHLO3u8g4pjNkXn1Mkv4zybwtJ0uUZM00OAAADsfIGgDgWpaie1baXqXzwUeyBgC4VqJMg5OsAQCulSjJmnvWAAA4HCNrAIBrJcrImmQNAHCtREnWTIMDAOBwjKwBAK5lWR5ZUYyOo4kdTCRrAIBrUc8aAAA4AiNrAIBrJcoCM5I1nC8UHLSmDl1jXsjj055s4xg7PyD+q73AOKZ7hL3/4j8a+zvjmKdbio1j7JyHkUndxjHf/IfnjGMkyesxLxrydudlxjH/dXK0cUxrT6pxjCQVZrQZxzj5lZyJcs+aaXAAAByOkTUAwLWYBgcAwOESZRqcZA0AcC0rypG1W5I196wBAHA4RtYAANeyJFlRLFd38kr3LyJZAwBcKySPPLzBDAAAxBsjawCAa7EaHAAAhwtZHnkS4DlrpsEBAHA4RtYAANeyrChXg7tkOTjJGviCSVd9bBxzMmheUMHrMf8JcVFKp3HMx52jjGMkaWPgW8Yxp4LJxjEtgTTjmBRvr3HMoe4s4xhJOhYYYRzj78gwjrEzFesbZn4eJOm7o+uMY569bK7R8VawW2o0bsaWRLlnzTQ4AAAOx8gaAOBajKzP4c0339SiRYtUWFgoj8ejF154IeLrlmVp/fr1KigoUFpamkpLS/XRRx/Fqr8AAISdqboVzeYGxsm6o6ND06dP1+bNm/v8+mOPPaaf/vSn2rJli3bv3q0RI0aorKxMXV1dUXcWAIAvOrPALJrNDYynwRcuXKiFCxf2+TXLsrRp0yatW7dO3/72tyVJv/nNb5SXl6cXXnhBN954Y3S9BQAgAcV0gdmBAwfk9/tVWloa3peZmani4mLV1fW9ArG7u1ttbW0RGwAA/XF6dOyJYov3d9A/MU3Wfr9fkpSXlxexPy8vL/y1L6uqqlJmZmZ4KyoqimWXAABDWHSJOrrFaYMp7o9urV27Vq2treHt4MGD8e4SAACOEtNHt/Lz8yVJzc3NKigoCO9vbm7WjBkz+ozx+Xzy+Xyx7AYAIEFYiq4mtUtmwWM7sh4/frzy8/NVU1MT3tfW1qbdu3erpKQklk0BAMA0+LmcPHlSe/bs0Z49eySdXlS2Z88eNTU1yePx6K677tIPfvAD/f73v9f777+v5cuXq7CwUIsXL45x1wEAiI/Nmzdr3LhxSk1NVXFxsd55553zHt/S0qLy8nIVFBTI5/Pp8ssv1yuvvNLv9oynwd9991194xvfCP+9srJSkrRixQpt27ZN99xzjzo6OnTrrbeqpaVF11xzjaqrq5Waav7+ZAAAzisO8+A7duxQZWWltmzZouLiYm3atEllZWVqaGhQbm7uWccHAgF961vfUm5urv793/9do0eP1l//+ldlZWX1u03jZD1v3jxZ51nr7vF49PDDD+vhhx82/WgkAm+SeUwoaBzisbkO4vFx/8c45umWYuOYkUndxjHJHvPzYFdR6mfGMX89ZV40pDfZ/E5cb8g8Zt9nBRc+qA85aSeNY7J8p4xjQjKfim0P2LvGm3szjWNa/sfZCeh8enu6Bq2Qh6KdyrYRu3HjRq1evVorV66UJG3ZskUvv/yytm7dqvvuu++s47du3aoTJ07o7bffVnLy6YI348aNM2oz7qvBAQCwa7DfYBYI
BFRfXx/xPhGv16vS0tJzvk/k97//vUpKSlReXq68vDxNmTJFjz76qILB/v8CTiEPAEDC+/ILuc71pNKxY8cUDAb7fJ/IBx980Odn/+Uvf9Hrr7+um2++Wa+88or279+v22+/XT09PdqwYUO/+sfIGgDgWrFaDV5UVBTxgq6qqqqY9TEUCik3N1e/+MUvNHPmTC1dulTf//73tWXLln5/BiNrAIB7WR5b950j4iUdPHhQGRkZ4d3nev9HTk6OkpKS1NzcHLG/ubk5/K6RLysoKFBycrKSkj5fs/OVr3xFfr9fgUBAKSkpF+wmI2sAQMLLyMiI2M6VrFNSUjRz5syI94mEQiHV1NSc830ic+bM0f79+xUKhcL7PvzwQxUUFPQrUUskawCAi8WjRGZlZaWefPJJ/frXv9af//xn3Xbbbero6AivDl++fLnWrl0bPv62227TiRMndOedd+rDDz/Uyy+/rEcffVTl5eX9bpNpcACAe8XhOeulS5fq6NGjWr9+vfx+v2bMmKHq6urworOmpiZ5vZ+PhYuKivTHP/5Ra9as0bRp0zR69Gjdeeeduvfee/vdJskaAABDFRUVqqio6PNrtbW1Z+0rKSnRrl27bLdHsgYAuFa07/d2y7vBSdYAAHdzS+msKLDADAAAh2NkDQBwLabBAQBwujisBo8HkjUGlSfZ/JKzus2rTbUu+ZpxjCTt7jpgHNNjmVcSs1M5qjuUbByT7Ald+KA+HOrOMo7p6O3fyx2idVGKeVWrFO/gVSzrHKTzkJ5iXrlNkgqTW4xjTl5idr0GuwfzDqvnb1s08c7HPWsAAByOkTUAwL2YBgcAwOESJFkzDQ4AgMMxsgYAuFeMSmQ6HckaAOBaditnfTHeDZgGBwDA4RhZAwDcK0EWmJGsAQDulSD3rJkGBwDA4RhZAwBcy2Od3qKJdwOSNQDAvbhnDcSe1W2v+IBxO//7mK04r8wLX3T0+oxjfN4e45gey/yuVchGjF2BkJ2CJuYxw2wWJ7Gj18b5y/Z1Gsd4bXxPds6dJHWEzK/XUf9ldr329phf37ZxzxoAADgBI2sAgHsxDQ4AgMMlSLJmGhwAAIdjZA0AcK8EGVmTrAEA7sVqcAAA4ASMrAEArsUbzAAAcLoEuWfNNDgAAA5HsgYAwOGYBgcAuJZHUd6zjllPBhbJGo7nSU4xjnnmq9tstfV06yzjGHtFOcyLMCTZ+Il0Kmiv2EOyjcISw4cFjGPsFKPISO4yjgkO4uM5ds5dm41iMG09qcYxktTYlWsck/r/PjA6vtcyvxZs49EtAADgBIysAQDulSCrwUnWAAD3SpBkzTQ4AAAOx8gaAOBavMEMAACnYxocAAA4ASNrAIB7JcjImmQNAHCtRLlnzTQ4AAAOx8gaAOBeCfK6UZI1AMC9uGcNOEPTfebFNf7Y8amttj7rGW4c4/P2Gsd4LfOfEMlW0DimzbJX7MG8NIk00kYhDzsFNtK85u209aYZx0hSu40CGy3d5m21B8zbsVPYRZIKUlqMY95PyjQ63mMN3h1W7lkDAABHYGQNAHCvBJkGNx5Zv/nmm1q0aJEKCwvl8Xj0wgsvRHz9lltukcfjidgWLFgQq/4CAPA56/OpcDvbkE3WHR0dmj59ujZv3nzOYxYsWKDDhw+Ht2eeeSaqTgIAkMiMp8EXLlyohQsXnvcYn8+n/Px8250CAKBfmAa3r7a2Vrm5uZo0aZJuu+02HT9+/JzHdnd3q62tLWIDAKBfrBhsLhDzZL1gwQL95je/UU1NjX70ox9p586dWrhwoYLBvh87qaqqUmZmZngrKiqKdZcAAHC1mK8Gv/HGG8N/njp1qqZNm6aJEyeqtrZW11133VnHr127VpWVleG/t7W1kbABAP3Cc9YxMmHCBOXk5Gj//v19ft3n8ykjIyNiAwAAnxvwZP3JJ5/o+PHjKigoGOimAAAYkoynwU+ePBkxSj5w4ID27Nmj7OxsZWdn66GHHtKSJUuUn5+v
xsZG3XPPPbr00ktVVlYW044DAJAoq8GNk/W7776rb3zjG+G/n7nfvGLFCj3xxBPau3evfv3rX6ulpUWFhYWaP3++HnnkEfl85u++BQDgfBLlnrVxsp43b56s8xQh+OMf/xhVh4Ysj40ybDaKPQxFNyz5D+OY5h6zwgNn5Ka0G8e0B82LZYRC5tfDqVCKcYzXEzKOsautx/w85PnMH9U8GTT/xf/wKXtrYY6fMi/sMjzZvAzKqLRO45isFPMYSeoOJRvHBFtazY637JSCiUIC/KikkAcAAA5HIQ8AgHtxzxoAAGdLlHvWTIMDAOBwjKwBAO7FNDgAAM7GNDgAAOjT5s2bNW7cOKWmpqq4uFjvvPNOv+KeffZZeTweLV682Kg9kjUAwL3iUCJzx44dqqys1IYNG/Tee+9p+vTpKisr05EjR84b9/HHH+vuu+/Wtddea9wmyRoA4F5xSNYbN27U6tWrtXLlSl1xxRXasmWLhg8frq1bt54zJhgM6uabb9ZDDz2kCRMmGLdJsgYAJLy2traIrbu7u8/jAoGA6uvrVVpaGt7n9XpVWlqqurq6c37+ww8/rNzcXK1atcpW/0jWAADXOrPALJpNkoqKipSZmRneqqqq+mzv2LFjCgaDysvLi9ifl5cnv9/fZ8xbb72lX/3qV3ryySdtf5+sBgcAuFeMHt06ePCgMjI+f4d8rIpPtbe367vf/a6efPJJ5eTk2P4ckjUAwL1ilKwzMjIikvW55OTkKCkpSc3NzRH7m5ublZ+ff9bxjY2N+vjjj7Vo0aLwvlDodIGdYcOGqaGhQRMnTrxguyRrDKoDj5YYxzye9RPjmF8c/bpxjCQN8waNY3zeXuMYO5WPuoPm/119Nr4fSUpLChjHtPSkGcc0d5tXw2o+lW4cE7JsVL2TvWpYGcldxjGBUJJxTFvA/HxL0oneEbbicFpKSopmzpypmpqa8ONXoVBINTU1qqioOOv4yZMn6/3334/Yt27dOrW3t+tf//VfVVRU1K92SdYAANeKx0tRKisrtWLFCs2aNUuzZ8/Wpk2b1NHRoZUrV0qSli9frtGjR6uqqkqpqamaMmVKRHxWVpYknbX/fEjWAAD3isPrRpcuXaqjR49q/fr18vv9mjFjhqqrq8OLzpqamuT1xnb9NskaAABDFRUVfU57S1Jtbe15Y7dt22bcHskaAOBaifJucJI1AMC9EqTqFi9FAQDA4RhZAwDcK0FG1iRrAIBref62RRPvBkyDAwDgcIysAQDuxTQ4AADOxqNbAAA4HSNrnJPHxpIEj/nyAE/S4LRj9ZgXbZCkpMsmGMc8s+xfjWN+eexa4xg7BTkkKdljHtcdcu5/I68nZCvuSLd5sYxA0LwYxYku86ISOWknjWMybRTXkKTWnlTjmC4bBVdabRTlGJXaYRwjSce6R9qIMj/niC3n/pQBAKA/XDI6jgbJGgDgWolyz5pHtwAAcDhG1gAA92KBGQAAzsY0OAAAcARG1gAA92IaHAAAZ2MaHAAAOAIjawCAezENDgCAw5GsAQBwtkS5Zz10krXXvIiAQvaKPciy8a9rmbdl2avBMGhyfnPMOObXx+cYx5wImBd7mDzSbxwjScd67BQ5MGen+EdHMMU4xhscvJ9EhzsyjGMuyzpqHBOwce4OnzLvm10hy7wATzBkvnwofVi3cYwk/fdnecYxGRTyiLuhk6wBAImHaXAAAJzNY1ny2Jnt/EK8G/DoFgAADsfIGgDgXkyDAwDgbImyGpxpcAAAHI6RNQDAvZgGBwDA2ZgGBwAAjsDIGgDgXkyDAwDgbIkyDU6yBgC4FyNrl7FblGOQJF18sXFM56xxxjHtY8z/SUfcYK/oRbLHvAhDr2VecGV0WotxzGc9w41jJHtFItK8AeMYO8UeAkEbxWpsau5MN46Zkm1+HXWHzL8nv42CIUlee1VxRiSb/9t6bfz0D1rmy4eGee39zDvRZl4YZ/DKoOBchk6yBgAkJLdMZUeDZA0A
cC/Lsle2+IvxLmA091JVVaUrr7xS6enpys3N1eLFi9XQ0BBxTFdXl8rLyzVq1CiNHDlSS5YsUXNzc0w7DQBAIjFK1jt37lR5ebl27dqlV199VT09PZo/f746OjrCx6xZs0Z/+MMf9Nxzz2nnzp06dOiQbrjhhph3HACAM6vBo9ncwGgavLq6OuLv27ZtU25ururr6zV37ly1trbqV7/6lbZv365vfvObkqSnnnpKX/nKV7Rr1y5dddVVses5AAAJsho8qjeYtba2SpKys7MlSfX19erp6VFpaWn4mMmTJ2vMmDGqq6vr8zO6u7vV1tYWsQEAgM/ZTtahUEh33XWX5syZoylTpkiS/H6/UlJSlJWVFXFsXl6e/P6+H+uoqqpSZmZmeCsqKrLbJQBAgvGEot/cwHayLi8v1759+/Tss89G1YG1a9eqtbU1vB08eDCqzwMAJBArBpsL2Hp0q6KiQi+99JLefPNNXXLJJeH9+fn5CgQCamlpiRhdNzc3Kz8/v8/P8vl88vl8droBAEBCMBpZW5aliooKPf/883r99dc1fvz4iK/PnDlTycnJqqmpCe9raGhQU1OTSkpKYtNjAAD+htXgfSgvL9f27dv14osvKj09PXwfOjMzU2lpacrMzNSqVatUWVmp7OxsZWRk6I477lBJSQkrwQEAsZcgL0UxStZPPPGEJGnevHkR+5966indcsstkqSf/OQn8nq9WrJkibq7u1VWVqaf//znMeksAABfRNWtPlj9+A0kNTVVmzdv1ubNm213arB0fqfYVlzLZebFB3rM350v3wnzmPRPzF/uf/BQtnlDknpC5usTLx7eceGDvuSS4S3GMSneXuMYScoa1mkc0xlKMY453m3jgrDh6Cl77RSltxjH+Gyc849Pml97Wb5TxjG9NgplSNIwj/n/p85e8+vBTmEXO+dbknqOptmKQ3zxbnAAgHslyEtRSNYAANdKlGnwqN5gBgAABh4jawCAe7EaHAAAZ2MaHAAAOAIjawCAe7EaHAAAZ2MaHAAAOAIjawCAe4Ws01s08S5AsgYAuBf3rAEAcDaPorxnHbOeDCzuWQMA4HBDZmTd+9oY45hPD/XYamvEPvOqWymt5u30DjePCYw0//3L05Zs3pAk32jzikReG3NOx2xUqErxmvdNknJ8J41j2ntSjWO8npBxzIku8/OQ7DVvR5Jyfe3GMXtPjDaOGZnSbRxj59zJbtUtG+dv+LCAccyRznTjGLuSW4fYGC1B3mA2xP7VAACJ5MyjW9FsdmzevFnjxo1TamqqiouL9c4775zz2CeffFLXXnutLrroIl100UUqLS097/F9IVkDAGBgx44dqqys1IYNG/Tee+9p+vTpKisr05EjR/o8vra2VsuWLdMbb7yhuro6FRUVaf78+fr000/73SbJGgDgXlYMNkMbN27U6tWrtXLlSl1xxRXasmWLhg8frq1bt/Z5/NNPP63bb79dM2bM0OTJk/XLX/5SoVBINTU1/W6TZA0AcC2PZUW9SVJbW1vE1t3d93qKQCCg+vp6lZaWhvd5vV6Vlpaqrq6uX33u7OxUT0+PsrOz+/19kqwBAAmvqKhImZmZ4a2qqqrP444dO6ZgMKi8vLyI/Xl5efL7/f1q695771VhYWFEwr+QIbMaHACQgEJ/26KJl3Tw4EFlZGSEd/t8vqi6dS4//OEP9eyzz6q2tlapqf1/koRkDQBwrS9OZduNl6SMjIyIZH0uOTk5SkpKUnNzc8T+5uZm5efnnzf2X/7lX/TDH/5Qr732mqZNm2bUT6bBAQDop5SUFM2cOTNicdiZxWIlJSXnjHvsscf0yCOPqLq6WrNmzTJul5E1AMC94vBu8MrKSq1YsUKzZs3S7NmztWnTJnV0dGjlypWSpOXLl2v06NHh+94/+tGPtH79em3fvl3jxo0L39seOXKkRo4c2a82SdYAAPeKwxvMli5dqqNHj2r9+vXy+/2aMWOGqqurw4vOmpqa5PV+
PnH9xBNPKBAI6O///u8jPmfDhg168MEH+9UmyRoA4FrRvIXsTLwdFRUVqqio6PNrtbW1EX//+OOP7TXyBdyzBgDA4Rw7sv7knmIl+fq/rH1l/h+N23gjaZJxjCQNG2v+nECKt9c4JjfVvKjEbr95QZOk/x5lHCNJIcu8uFzhcPOKJmle88II3SF7xUlOBMyrp5zsMX/Eo9cyLwZzvMO8b4vG7TOOkaT/biuwFWdqmI2iHIGQ+Y+tkwF7j+F0B83bykszL4KS4esyjum2cR4kafhhtxSF7KcEKeTh2GQNAMCFeEKnt2ji3YBpcAAAHI6RNQDAvZgGBwDA4eLwnHU8MA0OAIDDMbIGALhWrN4N7nQkawCAeyXIPWumwQEAcDhG1gAA97IUXT1rdwysSdYAAPfinjUAAE5nKcp71jHryYDinjUAAA7n2JH1RR8GNSw52O/jM5NOGbcxYeRx4xhJ6rHMf8c5FTQvLGGniMDN4981jnkz4zLjGEk62JZpHDPM0/9/0zMyhpkXOfikK804RpKOnEo3jrFTJMJ/3Pzcjc8/Zhyzt3W0cYwkNbVkGcdkpZn/O41M7jaOCYTMi6BkJJv3za4T3eYFV+wUxbHz/0KSkttdMpTsrwRZDe7YZA0AwAWFJEVTSIxCHgAAIBYYWQMAXIvV4AAAOF2C3LNmGhwAAIdjZA0AcK8EGVmTrAEA7pUgyZppcAAAHI6RNQDAvRLkOWuSNQDAtXh0CwAAp+OeNQAAcALHjqxHvPCuhnn6X/zi2ZaFxm00/YN5UQlJmn/ZB8YxxZkHbLVl6tPui4xjJmU022rrkuEttuJMHeo2L3rR3mteXEOyV1hi7MgTxjH/e8xu45hvjWgwjnk/kG8cI0npReaFcVqCI4xj/L3m/7Z2rvGPO0cZx0jSuOHmxX4yM8zP3afdWcYxp4IpxjGSNPJwj604xwpZkieK0XHIHSNrxyZrAAAuiGlwAADgBEbJuqqqSldeeaXS09OVm5urxYsXq6Ehcmpu3rx58ng8Edv3vve9mHYaAIDTrM9H13Y2DcGR9c6dO1VeXq5du3bp1VdfVU9Pj+bPn6+Ojo6I41avXq3Dhw+Ht8ceeyymnQYAQFJ0iTraKfRBZHTPurq6OuLv27ZtU25ururr6zV37tzw/uHDhys/397CFgAAECmqe9atra2SpOzs7Ij9Tz/9tHJycjRlyhStXbtWnZ2d5/yM7u5utbW1RWwAAPRLyIp+cwHbq8FDoZDuuusuzZkzR1OmTAnvv+mmmzR27FgVFhZq7969uvfee9XQ0KDf/e53fX5OVVWVHnroIbvdAAAkMit0eosm3gVsJ+vy8nLt27dPb731VsT+W2+9NfznqVOnqqCgQNddd50aGxs1ceLEsz5n7dq1qqysDP+9ra1NRUVFdrsFAMCQYytZV1RU6KWXXtKbb76pSy655LzHFhcXS5L279/fZ7L2+Xzy+ey9wAIAkOAS5Dlro2RtWZbuuOMOPf/886qtrdX48eMvGLNnzx5JUkFBga0OAgBwTqEoH78aivesy8vLtX37dr344otKT0+X3++XJGVmZiotLU2NjY3avn27rr/+eo0aNUp79+7VmjVrNHfuXE2bNm1AvgEAQAJjZH22J554QtLpF5980VNPPaVbbrlFKSkpeu2117Rp0yZ1dHSoqKhIS5Ys0bp162LWYQAAEo3xNPj5FBUVaefOnVF1CACAfrMU5cg6Zj0ZUEOmkEfya/XGMRNfs9dWo42YjycVG8ecGm9eXejEpP5XKjujfaK96mOe7IBxTGqaeYzHYxwij80qPHbiOj+6zDjmrc++ZhzzzN5e45jkDvMYSUo50nHhg74k+F/mVcEGz2e2oo7behWFefUxafAqYSXr3UFra1AkyDQ4hTwAAHC4ITOyBgAkoFBIUhQvNgkN8ZeiAAAQd0yDAwAAJ2BkDQBwrwQZWZOsAQDulSBvMGMaHAAAh2NkDQBwLcsKyYqizGU0sYOJZA0A
cC/Lim4qm3vWAAAMMCvKe9YuSdbcswYAwOEYWQMA3CsUkjxR3HfmnjW+KNiw3zgmxUZdhPxqGzHmIRji7JV2AeKAaXAAAOAEjKwBAK5lhUKyopgG59EtAAAGGtPgAADACRhZAwDcK2RJnqE/siZZAwDcy7IkRfPoljuSNdPgAAA4HCNrAIBrWSFLVhTT4BYjawAABpgVin6zYfPmzRo3bpxSU1NVXFysd95557zHP/fcc5o8ebJSU1M1depUvfLKK0btkawBAK5lhayoN1M7duxQZWWlNmzYoPfee0/Tp09XWVmZjhw50ufxb7/9tpYtW6ZVq1bpT3/6kxYvXqzFixdr3759/W7TYzlsDqCtrU2ZmZmap29rmCc53t0BABjqtXpUqxfV2tqqjIyMAWkjnCs834kqV/RaPaq1njfqa3Fxsa688kr927/9myQpFAqpqKhId9xxh+67776zjl+6dKk6Ojr00ksvhfddddVVmjFjhrZs2dKvNh13z/rM7w696onqOXcAQHz0qkfS4NwP7rW6oyrGcaavbW1tEft9Pp98Pt9ZxwcCAdXX12vt2rXhfV6vV6Wlpaqrq+uzjbq6OlVWVkbsKysr0wsvvNDvfjouWbe3t0uS3pLZfD4AwFna29uVmZk5IJ+dkpKi/Px8veWPPleMHDlSRUVFEfs2bNigBx988Kxjjx07pmAwqLy8vIj9eXl5+uCDD/r8fL/f3+fxfr+/3310XLIuLCzUwYMHlZ6eLo/HE/G1trY2FRUV6eDBgwM2teIGnIfTOA+ncR5O4zyc5oTzYFmW2tvbVVhYOGBtpKam6sCBAwoEAlF/lmVZZ+WbvkbV8eS4ZO31enXJJZec95iMjIyE/s94BufhNM7DaZyH0zgPp8X7PAzUiPqLUlNTlZqaOuDtfFFOTo6SkpLU3Nwcsb+5uVn5+X0XHM7Pzzc6vi+sBgcAoJ9SUlI0c+ZM1dTUhPeFQiHV1NSopKSkz5iSkpKI4yXp1VdfPefxfXHcyBoAACerrKzUihUrNGvWLM2ePVubNm1SR0eHVq5cKUlavny5Ro8eraqqKknSnXfeqa9//ev68Y9/rL/7u7/Ts88+q3fffVe/+MUv+t2mq5K1z+fThg0bHHcvYbBxHk7jPJzGeTiN83Aa52HgLV26VEePHtX69evl9/s1Y8YMVVdXhxeRNTU1yev9fOL66quv1vbt27Vu3Trdf//9uuyyy/TCCy9oypQp/W7Tcc9ZAwCASNyzBgDA4UjWAAA4HMkaAACHI1kDAOBwrknWpuXIhqIHH3xQHo8nYps8eXK8uzXg3nzzTS1atEiFhYXyeDxnvU/XsiytX79eBQUFSktLU2lpqT766KP4dHYAXeg83HLLLWddHwsWLIhPZwdIVVWVrrzySqWnpys3N1eLFy9WQ0NDxDFdXV0qLy/XqFGjNHLkSC1ZsuSsF1K4XX/Ow7x58866Hr73ve/FqceIliuStWk5sqHsq1/9qg4fPhze3nrrrXh3acB1dHRo+vTp2rx5c59ff+yxx/TTn/5UW7Zs0e7duzVixAiVlZWpq6trkHs6sC50HiRpwYIFEdfHM888M4g9HHg7d+5UeXm5du3apVdffVU9PT2aP3++Ojo6wsesWbNGf/jDH/Tcc89p586dOnTokG644YY49jr2+nMeJGn16tUR18Njjz0Wpx4japYLzJ492yovLw//PRgMWoWFhVZVVVUcezX4NmzYYE2fPj3e3YgrSdbzzz8f/nsoFLLy8/Otxx9/PLyvpaXF8vl81jPPPBOHHg6OL58Hy7KsFStWWN/+9rfj0p94OXLkiCXJ2rlzp2VZp//tk5OTreeeey58zJ///GdLklVXVxevbg64L58Hy7Ksr3/969add94Zv04hphw/sj5Tjqy0tDS870LlyIayjz76SIWFhZowYYJuvvlmNTU1xbtLcXXgwAH5/f6I6yMzM1PFxcUJeX3U1tYqNzdXkyZN0m233abjx4/Hu0sD
qrW1VZKUnZ0tSaqvr1dPT0/E9TB58mSNGTNmSF8PXz4PZzz99NPKycnRlClTtHbtWnV2dsaje4gBx7/BzE45sqGquLhY27Zt06RJk3T48GE99NBDuvbaa7Vv3z6lp6fHu3txcabEXLTl54aCBQsW6IYbbtD48ePV2Nio+++/XwsXLlRdXZ2SkpLi3b2YC4VCuuuuuzRnzpzwm6D8fr9SUlKUlZUVcexQvh76Og+SdNNNN2ns2LEqLCzU3r17de+996qhoUG/+93v4thb2OX4ZI3PLVy4MPznadOmqbi4WGPHjtVvf/tbrVq1Ko49gxPceOON4T9PnTpV06ZN08SJE1VbW6vrrrsujj0bGOXl5dq3b19CrNs4n3Odh1tvvTX856lTp6qgoEDXXXedGhsbNXHixMHuJqLk+GlwO+XIEkVWVpYuv/xy7d+/P95diZsz1wDXx9kmTJignJycIXl9VFRU6KWXXtIbb7wRUVI3Pz9fgUBALS0tEccP1evhXOehL8XFxZI0JK+HROD4ZG2nHFmiOHnypBobG1VQUBDvrsTN+PHjlZ+fH3F9tLW1affu3Ql/fXzyySc6fvz4kLo+LMtSRUWFnn/+eb3++usaP358xNdnzpyp5OTkiOuhoaFBTU1NQ+p6uNB56MuePXskaUhdD4nEFdPgFypHlijuvvtuLVq0SGPHjtWhQ4e0YcMGJSUladmyZfHu2oA6efJkxGjgwIED2rNnj7KzszVmzBjddddd+sEPfqDLLrtM48eP1wMPPKDCwkItXrw4fp0eAOc7D9nZ2XrooYe0ZMkS5efnq7GxUffcc48uvfRSlZWVxbHXsVVeXq7t27frxRdfVHp6evg+dGZmptLS0pSZmalVq1apsrJS2dnZysjI0B133KGSkhJdddVVce597FzoPDQ2Nmr79u26/vrrNWrUKO3du1dr1qzR3LlzNW3atDj3HrbEezl6f/3sZz+zxowZY6WkpFizZ8+2du3aFe8uDbqlS5daBQUFVkpKijV69Ghr6dKl1v79++PdrQH3xhtvWJLO2lasWGFZ1unHtx544AErLy/P8vl81nXXXWc1NDTEt9MD4HznobOz05o/f7518cUXW8nJydbYsWOt1atXW36/P97djqm+vn9J1lNPPRU+5tSpU9btt99uXXTRRdbw4cOt73znO9bhw4fj1+kBcKHz0NTUZM2dO9fKzs62fD6fdemll1r/9E//ZLW2tsa347CNEpkAADic4+9ZAwCQ6EjWAAA4HMkaAACHI1kDAOBwJGsAAByOZA0AgMORrAEAcDiSNQAADkeyBgDA4UjWAAA4HMkaAACHI1kDAOBw/x+M9hgTTUEUzwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class: Ankle boot\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "img, label = next(iter(train_loader))\n",
    "img = img[0].reshape(1, 28, 28).numpy()\n",
    "plt.figure()\n",
    "plt.imshow(img[0])\n",
    "plt.colorbar()\n",
    "plt.grid(False)\n",
    "plt.show()\n",
    "print(\"Class:\", class_names[label.argmax()])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a471ba68",
   "metadata": {},
   "source": [
    "## Model Definition\n",
    "\n",
    "We will use a three-layer perceptron for image classification. First, define the backbone of the perceptron:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9e66f7e6",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [],
   "source": [
    "@tvm.script.ir_module\n",
    "class MLP:\n",
    "    I.module_attrs({\"param_num\": 6, \"state_num\": 0})\n",
    "    @R.function\n",
    "    def backbone(\n",
    "        x: R.Tensor((batch_size, 784), \"float32\"),\n",
    "        w0: R.Tensor((784, 128), \"float32\"),\n",
    "        b0: R.Tensor((128,), \"float32\"),\n",
    "        w1: R.Tensor((128, 128), \"float32\"),\n",
    "        b1: R.Tensor((128,), \"float32\"),\n",
    "        w2: R.Tensor((128, 10), \"float32\"),\n",
    "        b2: R.Tensor((10,), \"float32\"),\n",
    "    ) -> R.Tensor((batch_size, 10), \"float32\"):\n",
    "        with R.dataflow():\n",
    "            lv0 = R.matmul(x, w0)\n",
    "            lv1 = R.add(lv0, b0)\n",
    "            lv2 = R.nn.relu(lv1)\n",
    "            lv3 = R.matmul(lv2, w1)\n",
    "            lv4 = R.add(lv3, b1)\n",
    "            lv5 = R.nn.relu(lv4)\n",
    "            lv6 = R.matmul(lv5, w2)\n",
    "            out = R.add(lv6, b2)\n",
    "            R.output(out)\n",
    "        return out"
   ]
  },
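  {
   "cell_type": "markdown",
   "id": "7a011868",
   "metadata": {},
   "source": [
    "## Method 1: Using the Trainer API\n",
    "\n",
    "### Trainer Structure\n",
    "\n",
    "The simpler way to train a given model is the Trainer API, which provides the core interfaces for updating parameters and running inference.\n",
    "\n",
    "To build a trainer, first create the optimizer and the loss function. Only their **hyperparameters** (such as the learning rate and the reduction method) need to be specified; model parameters are not required at this stage."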
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e9d28617",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = CrossEntropyLoss(reduction=\"sum\")\n",
    "opt = SGD(0.01, weight_decay=0.01)"
   ]
  },
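  {
   "cell_type": "markdown",
   "id": "c0a1e5a1",
   "metadata": {},
   "source": [
    "To make the two hyperparameter choices above concrete, here is a minimal NumPy sketch of what a cross-entropy loss with `reduction=\"sum\"` and an SGD step with weight decay compute in their textbook form. The helpers `cross_entropy_sum` and `sgd_step` are hypothetical names for this illustration and are not part of TVM; TVM's own implementations may differ in detail.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def cross_entropy_sum(logits, onehot):\n",
    "    # Numerically stable log-softmax over the class axis.\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
    "    # reduction='sum': add up the per-sample losses instead of averaging them.\n",
    "    return float(-(onehot * log_probs).sum())\n",
    "\n",
    "def sgd_step(param, grad, lr=0.01, weight_decay=0.01):\n",
    "    # Textbook SGD with L2 weight decay folded into the gradient.\n",
    "    return param - lr * (grad + weight_decay * param)\n",
    "\n",
    "logits = np.zeros((2, 10), dtype='float32')\n",
    "onehot = np.eye(10, dtype='float32')[[3, 7]]\n",
    "loss_value = cross_entropy_sum(logits, onehot)  # uniform logits give 2 * ln(10)\n",
    "new_w = sgd_step(np.ones(3), np.full(3, 0.5))\n",
    "```\n",
    "\n",
    "With `reduction=\"sum\"` the loss scales with the batch size, so the effective step size depends on both the learning rate and the batch size."
   ]
  },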
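  {
   "cell_type": "markdown",
   "id": "a2e537b1",
   "metadata": {},
   "source": [
    "Next, construct a `SetupTrainer`. It is a helper class for `Trainer` and is essentially a pass that transforms the backbone module into a complete, legalized trainer module.\n",
    "\n",
    "The transformed module contains the following methods:\n",
    "- `predict`: the model's prediction function (provided by the input module)\n",
    "- `loss`: computes the specified loss between the predictions and the ground-truth labels\n",
    "- `loss_adjoint`: computes the loss and the adjoints (gradients) of the parameters\n",
    "- `update_params`: takes the parameters, their gradients, and the optimizer state as input, and returns the updated parameters and the new optimizer state. It carries a function attribute named `optim_state` that holds the initial state of the specified optimizer.\n",
    "\n",
    "Constructing a `SetupTrainer` requires:\n",
    "1. The loss function\n",
    "2. The optimizer\n",
    "3. The `struct_info` of the model output and of the labels"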
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f8efe5c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "out_sinfo = relax.TensorStructInfo((batch_size, 10), \"float32\")\n",
    "label_sinfo = relax.TensorStructInfo((batch_size, 10), \"int64\")\n",
    "\n",
    "setup_trainer = SetupTrainer(loss, opt, [out_sinfo, label_sinfo])"
   ]
  },
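  {
   "cell_type": "markdown",
   "id": "2a9c86f0",
   "metadata": {},
   "source": [
    "Finally, we introduce `Trainer`. `Trainer` is a runtime component: it uses `SetupTrainer` to set up the backbone module, then builds and runs the resulting module, and it maintains the runtime values of the parameters internally.\n",
    "\n",
    "Constructing a `Trainer` requires:\n",
    "1. The backbone module\n",
    "2. The number of parameters $n$\n",
    "3. The `SetupTrainer` instance\n",
    "\n",
    "The first $n$ arguments of the backbone function are treated as model parameters, which the optimizer updates during training."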
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "852105ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "target = \"llvm\"\n",
    "dev = tvm.device(target, 0)\n",
    "train_mod = setup_trainer(MLP)\n",
    "ex = tvm.compile(train_mod, target)\n",
    "vm = relax.VirtualMachine(ex, dev, profile=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e48be560",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `MLP` is the backbone module defined above; it has 6 parameters (w0, b0, w1, b1, w2, b2)\n",
    "trainer = Trainer(MLP, 6, setup_trainer)\n",
    "# build the IRModule in the trainer\n",
    "trainer.build(target=\"llvm\", device=tvm.cpu(0))"
   ]
  },
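  {
   "cell_type": "markdown",
   "id": "9bfdaf2e",
   "metadata": {},
   "source": [
    "### Training Procedure\n",
    "\n",
    "With the trainer built, we can run a standard training procedure on top of it. We randomly initialize the parameters and train for 5 epochs.\n",
    "\n",
    "`Trainer` provides the `xaiver_uniform_init_params` method, which initializes all parameters with Xavier uniform initialization (the identifier is spelled `xaiver` in the API). For custom parameter handling, use:\n",
    "- `trainer.load_params(extern_param_dict: Dict[str, Union[np.ndarray, NDArray]])` to load preset parameters\n",
    "- `trainer.export_params() -> Dict[str, NDArray]` to export the current parameters\n",
    "\n",
    "The `update_params` method updates the parameters. Internally it:\n",
    "1. **Runs the forward pass** to obtain the model output and the loss\n",
    "2. **Computes the gradients** of the parameters\n",
    "3. **Updates the parameters** according to the optimizer algorithm\n",
    "4. **Returns the loss** to the caller\n",
    "\n",
    "The `predict` method is meant for inference: it takes a batch of features and returns the predictions, that is, the output of the backbone."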
   ]
  },
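  {
   "cell_type": "markdown",
   "id": "d7b2f4e2",
   "metadata": {},
   "source": [
    "For reference, Xavier (Glorot) uniform initialization draws each weight from $U(-a, a)$ with $a = \\sqrt{6 / (n_{in} + n_{out})}$. The sketch below shows this standard formula in NumPy for a 2-D weight matrix. `xavier_uniform` is a hypothetical helper written for this illustration; how TVM's `Trainer` treats biases or other shapes is not covered here.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def xavier_uniform(shape, rng):\n",
    "    # For a dense layer weight of shape (fan_in, fan_out).\n",
    "    fan_in, fan_out = shape\n",
    "    bound = np.sqrt(6.0 / (fan_in + fan_out))\n",
    "    return rng.uniform(-bound, bound, size=shape).astype('float32')\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "w0 = xavier_uniform((784, 128), rng)\n",
    "bound = np.sqrt(6.0 / (784 + 128))\n",
    "```\n",
    "\n",
    "The bound shrinks as the layer gets wider, which keeps the variance of activations roughly constant from layer to layer at the start of training."
   ]
  },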
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc827acc",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.xaiver_uniform_init_params()\n",
    "\n",
    "\n",
    "epochs = 5\n",
    "log_interval = 200\n",
    "\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        loss = trainer.update_params(data.numpy(), target.numpy())\n",
    "\n",
    "        if batch_idx % log_interval == 0 or batch_idx == len(train_loader) - 1:\n",
    "            print(f\"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} \"\n",
    "                f\"({100. * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.numpy():.2f}\")\n",
    "\n",
    "    total, correct = 0, 0\n",
    "    for data, target in test_loader:\n",
    "        predict = trainer.predict(data.numpy()) # batch_size * 10\n",
    "        total += len(data)\n",
    "        correct += np.sum(predict.numpy().argmax(1) == target.numpy().argmax(1))\n",
    "\n",
    "    print(f\"Train Epoch: {epoch} Accuracy on test dataset: {100.0 * correct / total:.2f}%\")"
   ]
  },
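  {
   "cell_type": "markdown",
   "id": "d653216f",
   "metadata": {},
   "source": [
    "### Why Separate Trainer and SetupTrainer?\n",
    "\n",
    "The design separates compile-time and runtime responsibilities:\n",
    "\n",
    "1. **Compile-time components** (`SetupTrainer` and everything before it):\n",
    "   - Build the complete computation graph (IRModule)\n",
    "   - Perform all static analysis and optimization\n",
    "   - Produce general, deployable computation logic\n",
    "\n",
    "2. **Runtime component** (`Trainer`):\n",
    "   - Takes the IRModule produced at compile time\n",
    "   - Manages the dynamic updates of model parameters\n",
    "   - Maintains the transient state of the training process\n",
    "\n",
    "This separation lets TVM:\n",
    "- Compile and optimize the computation graph on a server\n",
    "- Deploy the optimized IRModule to edge devices\n",
    "- Run only the necessary parameter updates on resource-constrained devices\n",
    "\n",
    "This is a key design decision behind TVM's \"compile once, run anywhere\" approach."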
   ]
  },
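  {
   "cell_type": "markdown",
   "id": "f3d4055d",
   "metadata": {},
   "source": [
    "## Method 2: Using the Low-Level Training APIs\n",
    "\n",
    "We can also build and run the IRModule directly with the low-level training APIs. These mainly consist of:\n",
    "- The loss function library\n",
    "- The optimizer library\n",
    "- The automatic differentiation pass\n",
    "\n",
    "### Loss Functions\n",
    "\n",
    "TVM provides a range of loss functions in the `tvm.relax.training.loss` module, including:\n",
    "- `CrossEntropyLoss`\n",
    "- `L1Loss`\n",
    "- `MSELoss`, among others\n",
    "\n",
    "You can also define a custom loss by subclassing `tvm.relax.training.loss.Loss`.\n",
    "\n",
    "Instantiating a loss class only requires its hyperparameters. Its `__call__()` method takes the `struct_info` of the model output and of the labels, and generates the corresponding Relax loss function:"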
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55392ebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "func = CrossEntropyLoss(reduction=\"sum\")(out_sinfo, label_sinfo)\n",
    "print(func)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f493d5dc",
   "metadata": {},
   "source": [
    "The automatic differentiation pass requires the backbone function and the loss function to be combined into a single function. The `relax.training.utils.append_loss` utility does this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d04884a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `MLP` is the module defined above; its only function is named `backbone`\n",
    "MLP[\"loss\"] = relax.training.utils.append_loss(MLP[\"backbone\"], func)\n",
    "MLP.show()"
   ]
  },
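  {
   "cell_type": "markdown",
   "id": "9bcd74b8",
   "metadata": {},
   "source": [
    "### The Gradient Pass\n",
    "\n",
    "To optimize the model parameters, we need their gradients. TVM provides the automatic differentiation pass `relax.transform.Gradient` for this.\n",
    "\n",
    "The automatic differentiation (AD) system is the core of the training workflow and is implemented by source code transformation. The current version places the following restrictions on the input function:\n",
    "1. **Single dataflow block**: the function must contain exactly one dataflow block\n",
    "2. **Operator support**: only basic Relax operators such as arithmetic and tuple operations are supported\n",
    "\n",
    "`Gradient` takes three key inputs:\n",
    "- The global variable name of the target function\n",
    "- The parameter variables whose gradients are required\n",
    "- The input IRModule\n",
    "\n",
    "It returns a new IRModule that contains the gradient computation."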
   ]
  },
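  {
   "cell_type": "markdown",
   "id": "e9c3a6f3",
   "metadata": {},
   "source": [
    "As a reference point for what a generated adjoint function computes, the sketch below hand-writes the backward pass for a single linear layer with a sum-of-squares loss in NumPy, then checks one gradient entry against a finite difference. It only illustrates reverse-mode differentiation; it does not use or reproduce `relax.transform.Gradient`, and every name in it is made up for the example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def loss_fn(x, w, b):\n",
    "    out = x @ w + b\n",
    "    return float((out ** 2).sum())\n",
    "\n",
    "def loss_adjoint(x, w, b):\n",
    "    # Forward pass, keeping the intermediate needed by the backward pass.\n",
    "    out = x @ w + b\n",
    "    loss = float((out ** 2).sum())\n",
    "    # Backward pass: propagate adjoints from the loss back to w and b.\n",
    "    out_adj = 2.0 * out\n",
    "    w_adj = x.T @ out_adj\n",
    "    b_adj = out_adj.sum(axis=0)\n",
    "    return loss, (w_adj, b_adj)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x, w, b = rng.normal(size=(4, 3)), rng.normal(size=(3, 2)), rng.normal(size=2)\n",
    "loss, (w_adj, b_adj) = loss_adjoint(x, w, b)\n",
    "\n",
    "# Finite-difference check on a single weight entry.\n",
    "eps = 1e-6\n",
    "w_pert = w.copy()\n",
    "w_pert[0, 0] += eps\n",
    "numeric = (loss_fn(x, w_pert, b) - loss) / eps\n",
    "```\n",
    "\n",
    "Like the function produced by the pass, `loss_adjoint` here returns the loss together with a tuple of gradients, one per differentiated parameter."
   ]
  },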
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca4216fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip the data input `x` (index 0); the next six arguments are w0, b0, w1, b1, w2, b2\n",
    "params = MLP[\"loss\"].params[1:7]\n",
    "\n",
    "MLP = relax.transform.Gradient(\n",
    "    MLP.get_global_var(\"loss\"),\n",
    "    require_grads=params\n",
    ")(MLP)\n",
    "MLP.show()"
   ]
  },
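  {
   "cell_type": "markdown",
   "id": "57d9e294",
   "metadata": {},
   "source": [
    "### Optimizers\n",
    "\n",
    "TVM provides several classic optimizers in the `relax.training.optimizer` module, including:\n",
    "- Plain SGD\n",
    "- SGD with momentum\n",
    "- Adam\n",
    "\n",
    "You can also implement a custom optimizer by subclassing `relax.training.optimizer.Optimizer`.\n",
    "\n",
    "Creating an optimizer instance only requires hyperparameters such as the learning rate. Initialization through `init()` takes either:\n",
    "- A single Relax variable, or\n",
    "- A list of Relax variables (the variable nodes in the computation graph)\n",
    "\n",
    "`init()` sets up the optimizer state. After initialization, the optimizer can be used in two ways:\n",
    "1. Call `get_function()` to obtain the corresponding Relax optimizer function\n",
    "2. Attach it to the computation of an existing IRModule"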
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e05646",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `SGD` was imported from tvm.relax.training.optimizer at the top of the notebook\n",
    "opt = SGD(0.1).init(params)\n",
    "MLP[\"SGD\"] = opt.get_function()\n",
    "print(MLP[\"SGD\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7eb8d17",
   "metadata": {},
   "source": [
    "### Training Procedure\n",
    "\n",
    "With the IRModule complete, we can start training. The steps are:\n",
    "1. Legalize the IRModule\n",
    "2. Compile it into an executable module\n",
    "3. Prepare the required input data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a072c659",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.runtime.container import tuple_object\n",
    "\n",
    "# Legalize and build the module\n",
    "lowered_mod = LegalizeOps()(MLP)\n",
    "ex = tvm.compile(lowered_mod, target=\"llvm\")\n",
    "vm = relax.VirtualMachine(ex, tvm.cpu())\n",
    "\n",
    "\n",
    "def _get_shape_as_int_list(var):\n",
    "    return [int(val) for val in var.struct_info.shape]\n",
    "\n",
    "# Initialize every parameter to ones, using the static shapes recorded in the IR\n",
    "params_list = [tvm.nd.array(np.ones(_get_shape_as_int_list(i), \"float32\")) for i in params]\n",
    "param_input_tuple = tuple_object(params_list)\n",
    "\n",
    "x_input, y_input = next(iter(train_loader))\n",
    "x_input = tvm.nd.array(x_input.numpy())\n",
    "y_input = tvm.nd.array(y_input.numpy())\n",
    "\n",
    "# The input should be (x_input, *param_input_tuple, y_input), matching the signature of `loss`\n",
    "# At the runtime of TVM, arguments should be TVM NDArray or TVM runtime ADT objects."
   ]
  },
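  {
   "cell_type": "markdown",
   "id": "ce316cb7",
   "metadata": {},
   "source": [
    "This demo runs a single training step; multiple steps follow the same logic.\n",
    "\n",
    "**How the core components interact**:\n",
    "1. **Adjoint function** (generated by the automatic differentiation pass):\n",
    "   - Input: the backbone function's inputs plus the ground-truth labels\n",
    "   - Output: the loss and a tuple of parameter gradients\n",
    "\n",
    "2. **Optimizer function** (built by the optimizer class):\n",
    "   - Input: the parameter tuple, the gradient tuple, and the optimizer state tuple\n",
    "   - Output: the updated parameter tuple and the new optimizer state tuple\n",
    "\n",
    "The optimizer state is available through `opt.state`. It holds the information the optimizer needs across steps, such as:\n",
    "- The number of training steps taken so far\n",
    "- The momentum buffers\n",
    "- Adaptive learning-rate statistics (for example, the first and second moment estimates in Adam)"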
   ]
  },
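  {
   "cell_type": "markdown",
   "id": "f1d4b8a4",
   "metadata": {},
   "source": [
    "The optimizer function is purely functional: it takes parameters, gradients, and state, and returns new parameters and new state without mutating anything. A minimal NumPy sketch of that calling convention for momentum SGD follows. `momentum_sgd` and its state layout `(num_steps, velocities)` are assumptions made for illustration and do not mirror the actual layout of TVM's optimizer state.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def momentum_sgd(params, grads, state, lr=0.1, momentum=0.9):\n",
    "    num_steps, velocities = state\n",
    "    # Update each velocity buffer, then step each parameter along it.\n",
    "    new_velocities = tuple(momentum * v + g for v, g in zip(velocities, grads))\n",
    "    new_params = tuple(p - lr * v for p, v in zip(params, new_velocities))\n",
    "    return new_params, (num_steps + 1, new_velocities)\n",
    "\n",
    "params = (np.ones(2), np.ones(3))\n",
    "grads = (np.full(2, 0.5), np.full(3, 0.5))\n",
    "state = (0, tuple(np.zeros_like(p) for p in params))\n",
    "params, state = momentum_sgd(params, grads, state)\n",
    "```\n",
    "\n",
    "Because the state is passed in and returned explicitly, the caller has to carry it from one step to the next."
   ]
  },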
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "103c7b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# forward and find the gradient; the argument order follows `loss`: input, parameters, labels\n",
    "loss, param_grad_tuple = vm[\"loss_adjoint\"](x_input, *param_input_tuple, y_input)\n",
    "# update parameters\n",
    "param_input_tuple, opt.state = vm[\"SGD\"](param_input_tuple, param_grad_tuple, opt.state)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56ea863a",
   "metadata": {},
   "source": [
    "Print the results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2df57dc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(loss.numpy())\n",
    "print(len(param_input_tuple), len(param_grad_tuple))\n",
    "print(param_input_tuple[0])\n",
    "print(param_grad_tuple[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
